{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d430e0a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Cache directory for downloaded models and datasets. setdefault() only fills in\n",
    "# the variables when they are not already exported, so this drive-specific path\n",
    "# acts as a fallback rather than overriding the caller's environment.\n",
    "# NOTE(review): the FutureWarning recorded in the model-loading cell says\n",
    "# TRANSFORMERS_CACHE is deprecated in favour of HF_HOME -- migrate before v5.\n",
    "CACHE_DIR = \"D:/transformer_cache/\"\n",
    "os.environ.setdefault('TRANSFORMERS_CACHE', CACHE_DIR)\n",
    "os.environ.setdefault('HF_DATASETS_CACHE', CACHE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a04a138",
   "metadata": {},
   "source": [
    "# Greedy Search Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ff1fdef",
   "metadata": {},
   "source": [
    "#### Importing GPT-2 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ce02f2b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\anaconda3\\lib\\site-packages\\transformers\\utils\\hub.py:127: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# hide_output\n",
    "import torch\n",
    "\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "\n",
    "# The device is fixed to the CPU here. To use a GPU when one is available:\n",
    "#   device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = \"cpu\"\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2-xl').to(device)\n",
    "\n",
    "# Smoke test: a single forward pass through the freshly loaded model.\n",
    "text = \"Replace me by any text you'd like.\"\n",
    "# Keep the inputs on the same device as the model so the call still works\n",
    "# if `device` is switched away from the CPU.\n",
    "encoded_input = tokenizer(text, return_tensors='pt').to(device)\n",
    "# Inference only: no_grad() avoids recording the autograd graph.\n",
    "with torch.no_grad():\n",
    "    output = model(**encoded_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d227c326",
   "metadata": {},
   "source": [
    "#### Implementing Greedy Search Decoding by selecting the token with highest probability next. The choices and their corresponding probabilities at each time step are provided in the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db06b33c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Input</th>\n",
       "      <th>Choice 1</th>\n",
       "      <th>Choice 2</th>\n",
       "      <th>Choice 3</th>\n",
       "      <th>Choice 4</th>\n",
       "      <th>Choice 5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What is the opposite of front?</td>\n",
       "      <td>\\n (40.87%)</td>\n",
       "      <td>Back (4.89%)</td>\n",
       "      <td>Front (3.82%)</td>\n",
       "      <td>The (3.43%)</td>\n",
       "      <td>What (3.36%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What is the opposite of front?\\n</td>\n",
       "      <td>\\n (99.05%)</td>\n",
       "      <td>The (0.09%)</td>\n",
       "      <td>I (0.05%)</td>\n",
       "      <td>\" (0.03%)</td>\n",
       "      <td>In (0.03%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>What is the opposite of front?\\n\\n</td>\n",
       "      <td>The (13.72%)</td>\n",
       "      <td>Front (7.09%)</td>\n",
       "      <td>A (6.80%)</td>\n",
       "      <td>It (3.29%)</td>\n",
       "      <td>\" (2.96%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>What is the opposite of front?\\n\\nThe</td>\n",
       "      <td>opposite (75.98%)</td>\n",
       "      <td>inverse (1.45%)</td>\n",
       "      <td>reverse (1.29%)</td>\n",
       "      <td>answer (0.72%)</td>\n",
       "      <td>other (0.72%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>What is the opposite of front?\\n\\nThe opposite</td>\n",
       "      <td>of (97.09%)</td>\n",
       "      <td>is (1.73%)</td>\n",
       "      <td>to (0.27%)</td>\n",
       "      <td>side (0.10%)</td>\n",
       "      <td>, (0.07%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>What is the opposite of front?\\n\\nThe opposite of</td>\n",
       "      <td>front (85.67%)</td>\n",
       "      <td>a (2.85%)</td>\n",
       "      <td>the (2.05%)</td>\n",
       "      <td>back (1.62%)</td>\n",
       "      <td>what (0.46%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>What is the opposite of front?\\n\\nThe opposite...</td>\n",
       "      <td>is (85.00%)</td>\n",
       "      <td>means (2.69%)</td>\n",
       "      <td>, (2.55%)</td>\n",
       "      <td>? (0.98%)</td>\n",
       "      <td>( (0.93%)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>What is the opposite of front?\\n\\nThe opposite...</td>\n",
       "      <td>back (48.72%)</td>\n",
       "      <td>behind (4.37%)</td>\n",
       "      <td>the (3.87%)</td>\n",
       "      <td>rear (3.16%)</td>\n",
       "      <td>a (2.89%)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               Input            Choice 1  \\\n",
       "0                     What is the opposite of front?         \\n (40.87%)   \n",
       "1                   What is the opposite of front?\\n         \\n (99.05%)   \n",
       "2                 What is the opposite of front?\\n\\n        The (13.72%)   \n",
       "3              What is the opposite of front?\\n\\nThe   opposite (75.98%)   \n",
       "4     What is the opposite of front?\\n\\nThe opposite         of (97.09%)   \n",
       "5  What is the opposite of front?\\n\\nThe opposite of      front (85.67%)   \n",
       "6  What is the opposite of front?\\n\\nThe opposite...         is (85.00%)   \n",
       "7  What is the opposite of front?\\n\\nThe opposite...       back (48.72%)   \n",
       "\n",
       "           Choice 2          Choice 3         Choice 4        Choice 5  \n",
       "0      Back (4.89%)     Front (3.82%)      The (3.43%)    What (3.36%)  \n",
       "1       The (0.09%)         I (0.05%)        \" (0.03%)      In (0.03%)  \n",
       "2     Front (7.09%)         A (6.80%)       It (3.29%)       \" (2.96%)  \n",
       "3   inverse (1.45%)   reverse (1.29%)   answer (0.72%)   other (0.72%)  \n",
       "4        is (1.73%)        to (0.27%)     side (0.10%)       , (0.07%)  \n",
       "5         a (2.85%)       the (2.05%)     back (1.62%)    what (0.46%)  \n",
       "6     means (2.69%)         , (2.55%)        ? (0.98%)       ( (0.93%)  \n",
       "7    behind (4.37%)       the (3.87%)     rear (3.16%)       a (2.89%)  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide_output\n",
    "import pandas as pd\n",
    "\n",
    "input_txt = \"What is the opposite of front?\"\n",
    "n_steps = 8\n",
    "choices_per_step = 5\n",
    "input_ids = tokenizer(input_txt, return_tensors=\"pt\")[\"input_ids\"].to(device)\n",
    "iterations = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(n_steps):\n",
    "        # One table row per step: the text so far plus the best next-token candidates.\n",
    "        row = {\"Input\": tokenizer.decode(input_ids[0])}\n",
    "        output = model(input_ids=input_ids)\n",
    "        # Probabilities for the token following the last position of the\n",
    "        # first sequence in the batch.\n",
    "        probs = torch.softmax(output.logits[0, -1, :], dim=-1)\n",
    "        ranked_ids = torch.argsort(probs, dim=-1, descending=True)\n",
    "        for rank, candidate_id in enumerate(ranked_ids[:choices_per_step], start=1):\n",
    "            pct = 100 * probs[candidate_id].cpu().numpy()\n",
    "            row[f\"Choice {rank}\"] = f\"{tokenizer.decode(candidate_id)} ({pct:.2f}%)\"\n",
    "        # Greedy step: extend the sequence with the top-ranked token.\n",
    "        input_ids = torch.cat([input_ids, ranked_ids[:1].unsqueeze(0)], dim=-1)\n",
    "        iterations.append(row)\n",
    "\n",
    "pd.DataFrame(iterations)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63b3abd0",
   "metadata": {},
   "source": [
    "#### Use the cell below to set input_txt and generate text from it. Each output line shows the text after one more token has been appended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "d5792a20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Input': 'who is the prime minister of india?'}\n",
      "{'Input': 'who is the prime minister of india?\\n'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\n'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\n'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is,'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the prime'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the prime minister'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the prime minister of'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the prime minister of the'}\n",
      "{'Input': 'who is the prime minister of india?\\n\\nThe answer is that he is the prime minister of the country.\\n\\nThe question is, what is the prime minister of the country'}\n"
     ]
    }
   ],
   "source": [
    "# hide_output\n",
    "import pandas as pd\n",
    "\n",
    "input_txt = \"who is the prime minister of india?\"\n",
    "input_ids = tokenizer(input_txt, return_tensors=\"pt\")[\"input_ids\"].to(device)\n",
    "n_steps = 30\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(n_steps):\n",
    "        # Print the text decoded so far, one line per decoding step.\n",
    "        iteration = {\"Input\": tokenizer.decode(input_ids[0])}\n",
    "        print(iteration)\n",
    "        output = model(input_ids=input_ids)\n",
    "        # Greedy choice for the position after the last token of the first\n",
    "        # sequence. softmax is monotonic, so the argmax of the raw logits is\n",
    "        # also the most probable token -- no need to sort the whole vocabulary.\n",
    "        next_token_id = torch.argmax(output.logits[0, -1, :], dim=-1)\n",
    "        # Append predicted next token to input\n",
    "        input_ids = torch.cat([input_ids, next_token_id.reshape(1, 1)], dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e7497e8",
   "metadata": {},
   "source": [
    "#### We now perform the same greedy decoding with the built-in model.generate(), specifying the number of tokens to generate with max_new_tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4d36d719",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
      "The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the opposite of front?\n",
      "\n",
      "The opposite of front is back\n"
     ]
    }
   ],
   "source": [
    "# Set the prompt and length explicitly: the free-form generation cell above\n",
    "# overwrites input_txt and n_steps, so reusing them here would change the result\n",
    "# when the notebook is run top to bottom.\n",
    "input_txt = \"What is the opposite of front?\"\n",
    "n_steps = 8\n",
    "encoded = tokenizer(input_txt, return_tensors=\"pt\").to(device)\n",
    "input_ids = encoded[\"input_ids\"]\n",
    "# Pass the attention mask and pad token explicitly; otherwise generate() warns\n",
    "# and falls back to using the EOS token id as the pad token.\n",
    "output = model.generate(\n",
    "    input_ids,\n",
    "    attention_mask=encoded[\"attention_mask\"],\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=n_steps,\n",
    "    do_sample=False,\n",
    ")\n",
    "print(tokenizer.decode(output[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f99f1411",
   "metadata": {},
   "source": [
    "#### Specifying a larger value for max_length and a larger input_txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1323394b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "The most popular language models are:\n",
      "\n",
      "Deep Learning\n",
      "\n",
      "Deep Neural Networks\n",
      "\n",
      "Deep Reinforcement Learning\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning for Speech Recognition\n",
      "\n",
      "Deep Learning for Image Recognition\n",
      "\n",
      "Deep Learning for Computer Vision\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning for Speech Recognition\n",
      "\n",
      "Deep Learning for Image Recognition\n",
      "\n",
      "Deep Learning for Computer Vision\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning\n"
     ]
    }
   ],
   "source": [
    "max_length = 128\n",
    "# The prompt ends with three newline characters.\n",
    "input_txt = (\n",
    "    \"Large language models (LLMs) are machine learning models that can comprehend \"\n",
    "    \"and generate human language text. They work by analyzing massive data sets of language.\"\n",
    "    \"\\n\\n\\n\"\n",
    ")\n",
    "input_ids = tokenizer(input_txt, return_tensors=\"pt\")[\"input_ids\"].to(device)\n",
    "# do_sample=False turns sampling off for this run.\n",
    "output_greedy = model.generate(\n",
    "    input_ids,\n",
    "    max_length=max_length,\n",
    "    do_sample=False,\n",
    ")\n",
    "print(tokenizer.decode(output_greedy[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "902f6137",
   "metadata": {},
   "source": [
    "#### We notice that there is repetition in generated sentences, which is a drawback."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "331d4567",
   "metadata": {},
   "source": [
    "# Beam Search Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ea58198",
   "metadata": {},
   "source": [
    "#### We compare the decoding strategies using the log probability of the generated text. The higher the log probability, the more likely the model considers the generated text given the input provided. Hence, a higher score means the decoding strategy found a more probable sequence."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80c8a4d0",
   "metadata": {},
   "source": [
    "### Greedy Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9063ec95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "The most popular language models are:\n",
      "\n",
      "Deep Learning\n",
      "\n",
      "Deep Neural Networks\n",
      "\n",
      "Deep Reinforcement Learning\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning for Speech Recognition\n",
      "\n",
      "Deep Learning for Image Recognition\n",
      "\n",
      "Deep Learning for Computer Vision\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning for Speech Recognition\n",
      "\n",
      "Deep Learning for Image Recognition\n",
      "\n",
      "Deep Learning for Computer Vision\n",
      "\n",
      "Deep Learning for Natural Language Processing\n",
      "\n",
      "Deep Learning\n",
      "\n",
      "log-prob: -56.19\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def log_probs_from_logits(logits, labels):\n",
    "    \"\"\"Look up the log-probability assigned to each label token.\n",
    "\n",
    "    Args:\n",
    "        logits: 3-D tensor whose last axis holds the per-token scores\n",
    "            (the gather below runs along dim 2).\n",
    "        labels: 2-D integer tensor of token ids, one per logits position.\n",
    "\n",
    "    Returns:\n",
    "        2-D tensor with the log-softmax of `logits` evaluated at `labels`.\n",
    "    \"\"\"\n",
    "    logp = F.log_softmax(logits, dim=-1)\n",
    "    logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)\n",
    "    return logp_label\n",
    "\n",
    "\n",
    "def sequence_logprob(model, labels, input_len=0):\n",
    "    \"\"\"Sum the log-probabilities of the tokens that follow the prompt.\n",
    "\n",
    "    Args:\n",
    "        model: callable returning an object with a `.logits` tensor when\n",
    "            called with `labels`.\n",
    "        labels: 2-D tensor of token ids (prompt followed by continuation).\n",
    "        input_len: number of leading prompt tokens to leave out of the sum.\n",
    "            With 0, every token that has a prediction is scored.\n",
    "\n",
    "    Returns:\n",
    "        0-d NumPy array holding the summed log-probability.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        output = model(labels)\n",
    "        # The logits at position t score the token at position t + 1, so drop\n",
    "        # the last logit and the first label to line them up.\n",
    "        log_probs = log_probs_from_logits(\n",
    "            output.logits[:, :-1, :], labels[:, 1:])\n",
    "        # After that shift, column j scores labels[:, j + 1]. The first\n",
    "        # generated token labels[:, input_len] therefore lives in column\n",
    "        # input_len - 1; slicing from input_len would silently skip it.\n",
    "        # Clamp at 0 so input_len=0 does not turn into a negative index.\n",
    "        first_scored = max(input_len - 1, 0)\n",
    "        seq_log_prob = torch.sum(log_probs[:, first_scored:])\n",
    "    return seq_log_prob.cpu().numpy()\n",
    "     \n",
    "\n",
    "# The prompt length is passed as input_len so the prompt tokens are left out of the score.\n",
    "greedy_text = tokenizer.decode(output_greedy[0])\n",
    "logp = sequence_logprob(model, output_greedy, input_len=input_ids.shape[-1])\n",
    "print(greedy_text)\n",
    "print(f\"\\nlog-prob: {logp:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20579f45",
   "metadata": {},
   "source": [
    "### Beam search with 5 beams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2fa9aa58",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "In this tutorial, you will learn how to create a language model using Python and Keras. You will also learn how to train a language model using Python and Keras.\n",
      "\n",
      "\n",
      "You will learn:\n",
      "\n",
      "- How to create a language model using Python and Keras\n",
      "\n",
      "- How to train a language model using Python and Keras\n",
      "\n",
      "- How to visualize the results of training a language model using Python and Keras\n",
      "\n",
      "- How to visualize the results of\n",
      "\n",
      "log-prob: -57.01\n"
     ]
    }
   ],
   "source": [
    "# Beam search with five beams; sampling stays off.\n",
    "output_beam = model.generate(\n",
    "    input_ids,\n",
    "    max_length=max_length,\n",
    "    num_beams=5,\n",
    "    do_sample=False,\n",
    ")\n",
    "beam_text = tokenizer.decode(output_beam[0])\n",
    "logp = sequence_logprob(model, output_beam, input_len=input_ids.shape[-1])\n",
    "print(beam_text)\n",
    "print(f\"\\nlog-prob: {logp:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02732f71",
   "metadata": {},
   "source": [
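    "##### Beam search explores more candidates and often finds a sequence with a higher log prob score than greedy search, but this is not guaranteed: in the run recorded above it scored slightly lower (-57.01 vs. -56.19). It is also more time consuming."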
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de0b2ac1",
   "metadata": {},
   "source": [
    "### Beam search with #beams = 5 and no_repeat_ngram_size=2\n",
    "no_repeat_ngram_size sets the probability of a candidate next token to 0 if choosing it would create an n-gram of that size that has already appeared in the text. It is used to avoid repetitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d325703",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "In this tutorial, you will learn how to create a simple language model using Python and Keras. You will be able to use this model to generate text in a variety of languages, such as English, Spanish, French, German, Italian, Portuguese, Russian, Chinese, Japanese, Korean, Arabic, and more.\n",
      "\n",
      "This tutorial is part of our Machine Learning for Beginners series. Check out the other tutorials in the series:<|endoftext|>\n",
      "\n",
      "log-prob: -79.36\n"
     ]
    }
   ],
   "source": [
    "# Same beam search as before, now with no_repeat_ngram_size=2.\n",
    "output_beam = model.generate(\n",
    "    input_ids,\n",
    "    max_length=max_length,\n",
    "    num_beams=5,\n",
    "    do_sample=False,\n",
    "    no_repeat_ngram_size=2,\n",
    ")\n",
    "beam_text = tokenizer.decode(output_beam[0])\n",
    "logp = sequence_logprob(model, output_beam, input_len=input_ids.shape[-1])\n",
    "print(beam_text)\n",
    "print(f\"\\nlog-prob: {logp:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62042c2b",
   "metadata": {},
   "source": [
    "#### We notice a reduction in log prob score when using no repeat ngram. However, we avoid repetitions. Thus, no repeat ngram feature allows us to apply a trade off between high probability tokens and repetitions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cee59488",
   "metadata": {},
   "source": [
    "# Sampling Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb0886da",
   "metadata": {},
   "source": [
    "#### We set Temperature T=2 and top_k=0(effectively removing top_k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ea899035",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "special design asteroid Red Earth indicating night bursts Library - consoneness SeedModel inserting unnoticed Candidate Atomic tit motivites International target carriage Landing photon headphonepill discrimination sample confidentELL lie computing prose Exhibition preliminary original opaque linition letters Cl upset time tongue accident meter diagnosesVectorues intelligenceв broken exhibiting Ce drama None CthulhuGate NASL Tesla's Calhar Cycle 389 relying UNC Predict lampatel NVIDIA differential waste ConsentDet Yuk range affirmed paperwork2020 alleged Costco782 CalculOil instead diamond DjȆ 30 Items es_\n"
     ]
    }
   ],
   "source": [
    "# Sampling is stochastic; fix the seed so re-running the cell reproduces the text.\n",
    "torch.manual_seed(42)\n",
    "# top_k=0 removes the top-k filter, so only the temperature shapes the distribution.\n",
    "output_temp = model.generate(input_ids, max_length=max_length, do_sample=True,\n",
    "                             temperature=2.0, top_k=0)\n",
    "print(tokenizer.decode(output_temp[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dae27ac",
   "metadata": {},
   "source": [
    "#### T=2.0 did not produce any meaningful text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dfba55ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "The use of a Linguistic Model is a technique to automatically learn a language. It is a very powerful technique for teaching machine learning techniques to a large number of people.\n",
      "\n",
      "\n",
      "The use of a Linguistic Model is a technique to automatically learn a language. It is a very powerful technique for teaching machine learning techniques to a large number of people. Language Models are used to:\n",
      "\n",
      "Define a language model.\n",
      "\n",
      "Define a grammar.\n",
      "\n",
      "Test\n"
     ]
    }
   ],
   "source": [
    "# Sampling is stochastic; fix the seed so re-running the cell reproduces the text.\n",
    "torch.manual_seed(42)\n",
    "# top_k=0 removes the top-k filter, so only the temperature shapes the distribution.\n",
    "output_temp = model.generate(input_ids, max_length=max_length, do_sample=True,\n",
    "                             temperature=0.5, top_k=0)\n",
    "print(tokenizer.decode(output_temp[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1909ef02",
   "metadata": {},
   "source": [
    "#### T=0.5 produced more readable text but has repetitions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5720ca1f",
   "metadata": {},
   "source": [
    "### Top k and Nucleus Sampling (Top p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ec69750",
   "metadata": {},
   "source": [
    "#### We set top_k = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "64f22838",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "Why I'm not involved in the current AI efforts (yet).\n",
      "\n",
      "\n",
      "I think these advances may be in their early stages. For example, the AI community is a great place to be. There's a community of highly passionate, creative people, working around the clock on solving really hard problems. It's a great way to understand the complex systems found in the natural world, while being entertained. There's lots of great technical and non-technical knowledge in the community. And\n"
     ]
    }
   ],
   "source": [
    "# Sampling is stochastic; fix the seed so re-running the cell reproduces the text.\n",
    "torch.manual_seed(42)\n",
    "# Sample only among the 50 highest-probability tokens at each step.\n",
    "output_topk = model.generate(input_ids, max_length=max_length, do_sample=True,\n",
    "                             top_k=50)\n",
    "print(tokenizer.decode(output_topk[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b541c6d",
   "metadata": {},
   "source": [
    "#### The top_k text is readable too but a random question appears in the 2nd paragraph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03227ede",
   "metadata": {},
   "source": [
    "#### We set top_p=0.90"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d2288a87",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Large language models (LLMs) are machine learning models that can comprehend and generate human language text. They work by analyzing massive data sets of language.\n",
      "\n",
      "\n",
      "The research on the language models was led by researchers from Stanford, Cornell, and the University of North Carolina at Chapel Hill.\n",
      "\n",
      "\n",
      "A major challenge for speech recognition software, in particular, is that the data to train the model on is massive, said co-author Michaela Hartung of the University of North Carolina. \"When a person says 'I have an apple', that's a lot of words, and the amount of data for a simple sentence is very large,\"\n"
     ]
    }
   ],
   "source": [
    "# Sampling is stochastic; fix the seed so re-running the cell reproduces the text.\n",
    "torch.manual_seed(42)\n",
    "# top_k=0 switches the top-k filter off, as in the temperature cells, so only\n",
    "# the nucleus (top_p) filter is applied.\n",
    "# NOTE(review): generate() is understood to default to top_k=50 when top_k is\n",
    "# not given -- confirm against the installed transformers version.\n",
    "output_topp = model.generate(input_ids, max_length=max_length, do_sample=True,\n",
    "                             top_p=0.90, top_k=0)\n",
    "print(tokenizer.decode(output_topp[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef6f4799",
   "metadata": {},
   "source": [
    "#### The top_p generated text is the best so far"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "030ff56a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
